{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
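    "# ScanNet visualisation\n",
    "\n",
    "Interactive viewer for the ScanNet object-detection dataset: each click on the button draws a random training sample, with points that carry a vote label in blue, all other points in grey, and the ground-truth bounding boxes coloured by semantic class."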
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "# `import importlib` alone does not guarantee that the `importlib.util`\n",
    "# submodule is loaded, so import it explicitly.\n",
    "import importlib.util\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# Load the local torch_points3d checkout (one directory above this notebook)\n",
    "# and register it in sys.modules, so that later `import torch_points3d...`\n",
    "# statements resolve to this copy.\n",
    "DIR = os.path.dirname(os.getcwd())\n",
    "torch_points3d = os.path.join(DIR, \"torch_points3d\")\n",
    "assert os.path.exists(torch_points3d), f\"torch_points3d package not found at {torch_points3d}\"\n",
    "\n",
    "MODULE_PATH = os.path.join(torch_points3d, \"__init__.py\")\n",
    "MODULE_NAME = \"torch_points3d\"\n",
    "spec = importlib.util.spec_from_file_location(MODULE_NAME, MODULE_PATH)\n",
    "module = importlib.util.module_from_spec(spec)\n",
    "sys.modules[spec.name] = module\n",
    "spec.loader.exec_module(module)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_points3d import datasets\n",
    "from torch_points3d.datasets import object_detection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# PyVista reads PYVISTA_OFF_SCREEN / PYVISTA_USE_PANEL when it is first imported,\n",
    "# so the virtual display and these variables must be set up before `import pyvista`.\n",
    "# NOTE(review): if an earlier cell already imported pyvista transitively, restart the kernel.\n",
    "os.system('/usr/bin/Xvfb :99 -screen 0 1024x768x24 &')  # headless X server for VTK rendering\n",
    "os.environ['DISPLAY'] = ':99'\n",
    "os.environ['PYVISTA_OFF_SCREEN'] = 'True'\n",
    "os.environ['PYVISTA_USE_PANEL'] = 'True'\n",
    "\n",
    "import glob\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import panel as pn\n",
    "import pyvista as pv\n",
    "from matplotlib.colors import ListedColormap\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "pv.set_plot_theme(\"document\")\n",
    "pn.extension('vtk')\n",
    "\n",
    "DIR = os.path.dirname(os.getcwd())\n",
    "if DIR not in sys.path:  # keep the cell idempotent on re-run\n",
    "    sys.path.append(DIR)\n",
    "\n",
    "from torch_points3d.datasets.object_detection.scannet import ScannetDataset, ScannetObjectDetection\n",
    "from torch_points3d.datasets.segmentation.scannet import Scannet, SCANNET_COLOR_MAP\n",
    "from torch_points3d.datasets.segmentation import IGNORE_LABEL"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the ScanNet dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Object-detection config for ScanNet, read from the repository's conf/ tree\n",
    "scannet_config_path = os.path.join(DIR, 'conf/data/object_detection/scannet.yaml')\n",
    "dataset_options = OmegaConf.load(scannet_config_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Point the config at the repository's data/ directory\n",
    "dataset_options.data.dataroot = os.path.join(DIR, \"data\")\n",
    "data_options = dataset_options.data\n",
    "dataset = ScannetDataset(data_options)\n",
    "\n",
    "# Clear the train split's transform so visualised samples are not passed through it\n",
    "# (presumably this shows them as stored — confirm in ScannetDataset).\n",
    "dataset.train_dataset.transform = None\n",
    "print(dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualise the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "first_sample = dataset.train_dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "first_sample.size_class_label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "first_sample.sem_cls_label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_random_data(event):\n",
    "    \"\"\"Button callback: draw a random training sample into the `pan` VTK pane.\n",
    "\n",
    "    Points whose `vote_label_mask` is non-zero are drawn in blue and the others\n",
    "    in semi-transparent grey. Every box selected by `box_label_mask` is drawn as a\n",
    "    translucent box coloured by its semantic class via SCANNET_COLOR_MAP.\n",
    "\n",
    "    `event` is the Panel click event and is not used. The function reads the\n",
    "    notebook globals `dataset`, `pan` and `SCANNET_COLOR_MAP`.\n",
    "    \"\"\"\n",
    "    train_set = dataset.train_dataset\n",
    "    sample = train_set[np.random.randint(0, len(train_set))]\n",
    "    pl = pv.Plotter(notebook=True)\n",
    "\n",
    "    # Points with a vote label in blue, the rest in grey.\n",
    "    # `!= 0` yields a boolean mask for boolean as well as 0/1 integer masks.\n",
    "    vote_mask = sample.vote_label_mask != 0\n",
    "    pl.add_points(sample.pos[vote_mask].numpy(), color=\"blue\")\n",
    "    pl.add_points(sample.pos[~vote_mask].numpy(), color=\"grey\", opacity=0.75)\n",
    "\n",
    "    # Bounding boxes: centres, labels and sizes are all selected with the same\n",
    "    # mask so that the k-th centre is paired with the k-th label.\n",
    "    box_mask = sample.box_label_mask != 0\n",
    "    centres = sample.center_label[box_mask].numpy()\n",
    "    labels = sample.sem_cls_label[box_mask].numpy()\n",
    "    sizes = sample.size_residual_label[box_mask].numpy() + train_set.MEAN_SIZE_ARR[labels]\n",
    "    for centre, size, label in zip(centres, sizes, labels):\n",
    "        lo, hi = centre - size / 2, centre + size / 2\n",
    "        box = pv.Box((lo[0], hi[0], lo[1], hi[1], lo[2], hi[2]))\n",
    "        nyu40_id = train_set.NYU40IDS[int(label)]\n",
    "        color = np.asarray(SCANNET_COLOR_MAP[nyu40_id]) / 255.\n",
    "        pl.add_mesh(box, color=color, show_edges=True, opacity=0.5)\n",
    "\n",
    "    pan.object = pl.ren_win"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initial, empty render window shown in the VTK pane until a sample is loaded.\n",
    "empty_plotter = pv.Plotter(notebook=True)\n",
    "pan = pn.panel(\n",
    "    empty_plotter.ren_win,\n",
    "    sizing_mode='scale_both',\n",
    "    aspect_ratio=1,\n",
    "    orientation_widget=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The callback draws a random training *sample* (not a model), so label it accordingly.\n",
    "button = pn.widgets.Button(name='Load random sample', button_type='primary')\n",
    "button.on_click(load_random_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Title and button on the left, the 3D view on the right.\n",
    "dashboard = pn.Row(\n",
    "    pn.Column('## ScanNet visualisation', button),\n",
    "    pan\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dashboard"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
